import copy
from typing import Union, Tuple

import numpy as np
from scipy.stats import sem
from tqdm import trange

from datasets.common import SplitDataset
from src.logistic_regression import LogisticRegressionModel
from src.logistic_regression import RegularizationType, LogisticRegression


def single_monte_carlo_shapley_estimate(
        dataset: SplitDataset,
        permutation: np.ndarray,
        regularization: float = 0.0,
        reg_type: Union[RegularizationType, str] = RegularizationType.L2,
        use_tqdm: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Performs 1 run of the Permutation Sampling Monte-Carlo estimation of Shapley values.

    Training points are added one at a time in the order given by
    ``permutation``. After each addition the model is refit, warm-started from
    the previous weights, and the summed test loss is recorded. A point's value
    is the drop in summed test loss caused by adding it:
    ``loss(before) - loss(after)``. The loss "before" the first point is that
    of the untrained, all-zero-weight starting model, i.e. v(empty set).

    Args:
        dataset: SplitDataset with ``train`` and ``test`` features/labels.
        permutation: Order in which training indices are added; assumed to be
            a permutation of ``range(n)``.
        regularization: Regularization strength passed to the model.
        reg_type: Regularization type passed to the model.
        use_tqdm: Show a tqdm progress bar over the n steps.

    Returns:
        - phi: (n,) estimated Shapley values for this permutation, indexed by
          training-point index.
        - weight_history: (n x d) array; row i holds the model weights after
          fitting on the first i + 1 points of the permutation.
    """
    n, d = dataset.train.features.shape
    test_features = dataset.test.features
    test_labels = dataset.test.labels

    model = LogisticRegressionModel(np.zeros(d), regularization, reg_type)
    phi = np.zeros(n)
    weight_history = np.zeros((n, d))

    # v(empty set): test loss of the untrained (zero-weight) model. Using it as
    # the baseline gives the first point of the permutation its real marginal
    # contribution; hard-coding that value to 0 would bias estimates toward 0.
    previous_loss = np.sum(model.get_model_losses(test_features, test_labels))

    for i in trange(n, desc="Single Monte-Carlo Run", disable=not use_tqdm):
        subset_indices = permutation[:i + 1]
        regression = LogisticRegression(
            dataset.train.features[subset_indices],
            dataset.train.labels[subset_indices]
        )
        # Carry the previous model over so the fit warm-starts from it
        # (assumes fit(warm_start=True) starts from regression.model's weights).
        regression.model = model
        regression.fit(warm_start=True, max_evals=15)
        model = regression.model

        current_loss = np.sum(model.get_model_losses(test_features, test_labels))
        phi[permutation[i]] = previous_loss - current_loss
        previous_loss = current_loss

        # Copy: the same model object may be updated in place by later fits.
        weight_history[i] = model.weights.copy()

    return phi, weight_history


def monte_carlo_shapley_estimates(
        dataset: SplitDataset,
        k: int = 10,
        regularization: float = 0.0,
        reg_type: Union[RegularizationType, str] = RegularizationType.L2,
        seed: int = 42,
        verbosity: int = 2
) -> np.ndarray:
    """
    Runs k Monte Carlo simulations to estimate Shapley values using permutation sampling.

    Note: this seeds NumPy's *global* RNG via ``np.random.seed(seed)``, so it
    affects any later code that draws from ``np.random``.

    Args:
        dataset: A SplitDataset object including a train and test set.
        k: Number of Monte Carlo runs.
        regularization: Regularization strength.
        reg_type: Type of regularization (L1, L2).
        seed: Random seed for reproducibility.
        verbosity: Verbosity level (0 = silent, 1 = print before each run, 2 = print + tqdm).

    Returns:
        A (k x n) array of Shapley value estimates from each Monte Carlo run.
    """
    np.random.seed(seed)
    n = dataset.train.features.shape[0]
    all_phi = np.zeros((k, n))

    for run in range(k):
        if verbosity >= 1:
            print(f"Running Monte Carlo simulation {run + 1}/{k}...")

        permutation = np.random.permutation(n)
        # The single-run estimator returns (phi, weight_history); only the
        # Shapley estimates belong in the (k x n) result. Assigning the whole
        # tuple into a length-n row would fail.
        phi, _weight_history = single_monte_carlo_shapley_estimate(
            dataset,
            permutation,
            regularization=regularization,
            reg_type=reg_type,
            use_tqdm=(verbosity >= 2)
        )
        all_phi[run] = phi

    return all_phi


def estimate_normalized_shapley_with_error(runs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Computes normalized Shapley value estimates and error bars using Monte Carlo estimates.

    The normalized value of point i is ``mu_i / T``, where ``mu_i`` is the mean
    of its per-run estimates and ``T = sum_j mu_j``. ``T`` contains ``mu_i``, so
    the two are correlated. The standard error is therefore propagated with the
    first-order delta method *including* the covariance term:

        Var(mu_i / T) ~= Var(mu_i) / T^2 + mu_i^2 Var(T) / T^4
                         - 2 mu_i Cov(mu_i, T) / T^3

    This equals the variance of the per-run linearized quantity
    ``runs[r, i] / T - mu_i * total_r / T^2``. The SEM is computed directly
    from that quantity, so it is non-negative by construction.

    Args:
        runs: (k x n) array of Monte Carlo Shapley value estimates (k runs, n data points)

    Returns:
        norm_phi: Normalized Shapley value estimates (n,)
        norm_sem: Estimated standard error of the normalized Shapley values (n,).
            NaN when k < 2 (the SEM is undefined). If the total is 0, both
            outputs contain inf/NaN.
    """
    runs = np.asarray(runs, dtype=float)

    # Empirical means across runs and the per-run totals.
    mu = np.mean(runs, axis=0)                # shape (n,)
    total_per_run = np.sum(runs, axis=1)      # shape (k,)
    total_mu = np.sum(mu)                     # == mean(total_per_run)

    # Normalized Shapley values (ratio of means).
    norm_phi = mu / total_mu

    # Per-run linearization of mu_i / T around (mu_i, T); its SEM is the
    # delta-method SEM with the covariance between mu_i and T accounted for.
    linearized = runs / total_mu - np.outer(total_per_run, mu) / total_mu ** 2
    norm_sem = sem(linearized, axis=0)

    return norm_phi, norm_sem


def influence_shapley_estimate(
        dataset: SplitDataset,
        model: LogisticRegressionModel,
        influences: np.ndarray
) -> np.ndarray:
    """
    Approximate normalized per-point values from influence-function weight shifts.

    Instead of retraining without each training point (leave-one-out), this
    function applies that point's influence vector to the full model's
    weights. It then measures how much the summed test loss changes.

    For point i the raw score is ``loss(model) - loss(model with weights +
    influences[i])``. The scores are then divided by their sum. Because of this
    normalization, flipping the sign convention of the raw score would give
    the same output.

    Args:
        dataset: A SplitDataset object with train/test sets; only the number of
            training rows and the test features/labels are used.
        model: A trained LogisticRegressionModel on the full dataset. It is not
            modified; each perturbation is applied to a deep copy.
        influences: An (n x d) array of influence function estimates. Row i is
            the vector to add to the weights to simulate removing point i.

    Returns:
        A (n,) numpy array of normalized scores. All zeros if the raw scores
        sum to exactly 0.
    """
    n, _ = dataset.train.features.shape
    test_x, test_y = dataset.test.features, dataset.test.labels

    def summed_test_loss(candidate):
        # Total (summed) loss of `candidate` over the test set.
        return np.sum(candidate.get_model_losses(test_x, test_y))

    baseline = summed_test_loss(model)

    scores = np.zeros(n)
    for idx in range(n):
        shifted = copy.deepcopy(model)
        shifted.weights = model.weights + influences[idx]
        scores[idx] = baseline - summed_test_loss(shifted)

    denominator = np.sum(scores)
    if denominator == 0:
        # Nothing to normalize against; avoid a division by zero.
        return np.zeros(n)
    return scores / denominator
